// -*- mode:c++ -*-

// Copyright (c) 2022 PLCT Lab
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


let {{

def VMemBase(name, Name, ea_code, memacc_code, mem_flags, inst_flags,
             base_class, postacc_code='',
             declare_template_base=VMemMacroDeclare,
             decode_template=BasicDecode, exec_template_base='',
             vfof=False, is_macroop=True):
    """Build the generated-code strings for one vector memory instruction.

    A macro-op is always emitted; unless ``is_macroop`` is false, the
    matching micro-op (declaration, execute, initiateAcc, completeAcc)
    is emitted as well.

    Arguments:
        name, Name: instruction mnemonic and class name handed to
            InstObjParams.  The micro-op uses ``name + '_micro'`` and
            ``Name + 'Micro'``.
        ea_code, memacc_code, postacc_code: code snippets made available
            to the templates under the keys of the same name.
        mem_flags: one flag name or a list of them.  Each is prefixed
            with ``Request::`` and OR-ed into ``memAccessFlags`` in the
            micro-op constructor.
        inst_flags: one flag or a list of flags, passed to both
            InstObjParams objects.
        base_class: class-name stem such as 'Vle'.  'MacroInst' is
            appended for the macro-op; the stem also selects the
            whole-register / indexed snippets and, through its 'Vl'
            prefix, the load templates rather than the store ones.
        declare_template_base: template used for the macro-op
            declaration.  When it is VMemTemplateMacroDeclare, the
            constructor and the micro-op method bodies are appended to
            header_output instead of decoder_output / exec_output.
        decode_template: template used for the decode block.
        exec_template_base: prefix used to look up the
            '<prefix>Constructor' and '<prefix>MicroDeclare' templates
            (and '<prefix>MicroExecute' for loads); also the stem of the
            micro-op base class '<prefix>MicroInst'.
        vfof: when true, the vfof_* snippets are filled in; otherwise
            they are empty strings.
        is_macroop: when false, return right after the macro-op pieces.

    Returns:
        The tuple (header_output, decoder_output, decode_block,
        exec_output).
    """
    # Keep the bare stem around: it drives the load/store and
    # whole/indexed decisions further down.
    base_type = base_class
    macro_class = base_class + 'MacroInst'

    # Make sure flags are in lists (convert to lists if not).
    mem_flags = makeList(mem_flags)
    inst_flags = makeList(inst_flags)

    macro_snippets = {
        'ea_code': ea_code,
        'memacc_code': memacc_code,
        'postacc_code': postacc_code,
    }
    iop = InstObjParams(name, Name, macro_class, macro_snippets, inst_flags)

    ctor_template = eval(exec_template_base + 'Constructor')

    header_output = declare_template_base.subst(iop)
    decoder_output = ''
    # NOTE(review): presumably VMemTemplateMacroDeclare declares a
    # templated class whose definitions must live in the header --
    # confirm against the template definitions.
    everything_in_header = declare_template_base is VMemTemplateMacroDeclare
    ctor_text = ctor_template.subst(iop)
    if everything_in_header:
        header_output += ctor_text
    else:
        decoder_output += ctor_text
    decode_block = decode_template.subst(iop)
    exec_output = ''

    if not is_macroop:
        # Nothing else to generate when no micro-op is wanted.
        return (header_output, decoder_output, decode_block, exec_output)

    # Snippets that record, test and publish the faulting element index
    # (presumably for fault-only-first loads -- confirm with the users of
    # the vfof_* keys).
    if vfof:
        vfof_set_code = '''
        auto base_addr = Rs1;
        uint32_t elem_size = (eew / 8);
        if (fault != NoFault) {
            auto addr_fault = dynamic_cast<AddressFault*>(fault.get());
            assert(addr_fault != nullptr);
            if (addr_fault) {
                auto fault_addr = addr_fault->trap_value();
                assert(fault_addr >= EA);
                *fault_elem_idx = (fault_addr - base_addr) / elem_size;
            }
        } else {
            *fault_elem_idx = mem_size / elem_size;
        }
        '''
        vfof_zero_idx_check_code = '''
        if (fault != NoFault && *fault_elem_idx == 0)
            return fault;
        '''
        vfof_get_code = '''
        Vfof_ud[0] = *fault_elem_idx;
        '''
    else:
        vfof_set_code = ''
        vfof_zero_idx_check_code = ''
        vfof_get_code = ''

    whole_register = base_type in ('VlWhole', 'VsWhole')
    indexed = base_type in ('VlIndex', 'VsIndex')
    mask_access = exec_template_base in ('Vlm', 'Vsm')

    if whole_register:
        # Whole-register accesses always move VLENB bytes and write
        # every element.
        calc_memsize_code = 'VLENB'
        is_vecWhole = 'true'
        wb_elem_mask = 'true'
    else:
        calc_memsize_code = 'calc_memsize(vmi.rs, vmi.re, eew_, rVl)'
        is_vecWhole = 'false'
        wb_elem_mask = '(ei < rVl) && (this->vm || elem_mask(v0, ei))'

    is_vecIndex = 'true' if indexed else 'false'

    # The 'Vlm' / 'Vsm' templates get a statement that rounds rVl up to
    # a multiple of eight and divides it by eight.
    divrVl = 'rVl = (rVl + 7) / 8' if mask_access else ''

    micro_snippets = {
        'ea_code': ea_code,
        'memacc_code': memacc_code,
        'postacc_code': postacc_code,
        'vfof_set_code': vfof_set_code,
        'vfof_zero_idx_check_code': vfof_zero_idx_check_code,
        'vfof_get_code': vfof_get_code,
        'calc_memsize_code': calc_memsize_code,
        'is_vecWhole': is_vecWhole,
        'is_vecIndex': is_vecIndex,
        'wb_elem_mask': wb_elem_mask,
        'divrVl': divrVl,
    }
    microiop = InstObjParams(name + '_micro', Name + 'Micro',
                             exec_template_base + 'MicroInst',
                             micro_snippets, inst_flags)

    if mem_flags:
        request_flags = '|'.join('Request::%s' % flag for flag in mem_flags)
        microiop.constructor += (
            '\n\tmemAccessFlags = ' + request_flags + ';')

    # Loads use their own Execute template plus the shared 'Vload'
    # access templates; everything else uses the 'Vstore' templates.
    if base_type.startswith('Vl'):
        execute_base = exec_template_base
        access_base = 'Vload'
    else:
        execute_base = 'Vstore'
        access_base = 'Vstore'

    micro_declare = eval(exec_template_base + 'MicroDeclare')
    micro_execute = eval(execute_base + 'MicroExecute')
    micro_initiate = eval(access_base + 'MicroInitiateAcc')
    micro_complete = eval(access_base + 'MicroCompleteAcc')

    # The micro-op declaration goes ahead of the macro-op text in the
    # header.
    header_output = micro_declare.subst(microiop) + header_output
    micro_bodies = ''.join(
        [template.subst(microiop)
         for template in (micro_execute, micro_initiate, micro_complete)])
    if everything_in_header:
        header_output += micro_bodies
    else:
        exec_output += micro_bodies

    return (header_output, decoder_output, decode_block, exec_output)

}};

// Format that forwards to VMemBase with base class 'Vle' and the 'Vle'
// template family. The default effective address is Rs1 + vmi.offset.
def format VleOp(
    memacc_code,
    ea_code={{ EA = Rs1 + vmi.offset; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='Vle', exec_template_base='Vle')
}};

// Format that forwards to VMemBase with base class 'Vle', the 'Vleff'
// template family and vfof enabled. The default effective address is
// Rs1 + VLENB * microIdx.
def format VleffOp(
    memacc_code,
    ea_code={{ EA = Rs1 + VLENB * microIdx; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='Vle', exec_template_base='Vleff', vfof=True)
}};

// Format that forwards to VMemBase with base class 'Vle' and the 'Vlm'
// template family. The default effective address is Rs1.
def format VlmOp(
    memacc_code,
    ea_code={{ EA = Rs1; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='Vle', exec_template_base='Vlm')
}};

// Format that forwards to VMemBase with base class 'VlWhole' and the
// 'VlWhole' template family. The default effective address is
// Rs1 + VLENB * microIdx.
def format VlWholeOp(
    memacc_code,
    ea_code={{ EA = Rs1 + VLENB * microIdx; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlWhole', exec_template_base='VlWhole')
}};

// Format that forwards to VMemBase with base class 'VlStride' and the
// 'VlStride' template family. The default effective address is
// Rs1 + Rs2*vmi.rs + vmi.offset.
def format VlStrideOp(
    memacc_code,
    ea_code={{ EA = Rs1 + Rs2*vmi.rs + vmi.offset; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlStride', exec_template_base='VlStride')
}};

// Format that forwards to VMemBase with base class 'VlIndex' and the
// 'VlIndex' template family, declared through VMemTemplateMacroDeclare
// and decoded through VMemTemplateDecodeBlock. ea_code has no default
// and must be supplied by the instruction definition.
def format VlIndexOp(
    memacc_code,
    ea_code,
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VlIndex',
        declare_template_base=VMemTemplateMacroDeclare,
        decode_template=VMemTemplateDecodeBlock,
        exec_template_base='VlIndex')
}};

// Format that forwards to VMemBase with base class 'Vse' and the 'Vse'
// template family. The default effective address is Rs1 + vmi.offset.
def format VseOp(
    memacc_code,
    ea_code={{ EA = Rs1 + vmi.offset; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='Vse', exec_template_base='Vse')
}};

// Format that forwards to VMemBase with base class 'Vse' and the 'Vsm'
// template family. The default effective address is Rs1.
def format VsmOp(
    memacc_code,
    ea_code={{ EA = Rs1; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='Vse', exec_template_base='Vsm')
}};

// Format that forwards to VMemBase with base class 'VsWhole' and the
// 'VsWhole' template family. The default effective address is
// Rs1 + VLENB * microIdx.
def format VsWholeOp(
    memacc_code,
    ea_code={{ EA = Rs1 + VLENB * microIdx; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsWhole', exec_template_base='VsWhole')
}};

// Format that forwards to VMemBase with base class 'VsStride' and the
// 'VsStride' template family. The default effective address is
// Rs1 + Rs2*vmi.rs + vmi.offset.
def format VsStrideOp(
    memacc_code,
    ea_code={{ EA = Rs1 + Rs2*vmi.rs + vmi.offset; }},
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsStride', exec_template_base='VsStride')
}};

// Format that forwards to VMemBase with base class 'VsIndex' and the
// 'VsIndex' template family, declared through VMemTemplateMacroDeclare
// and decoded through VMemTemplateDecodeBlock. ea_code has no default
// and must be supplied by the instruction definition.
def format VsIndexOp(
    memacc_code,
    ea_code,
    mem_flags=[],
    inst_flags=[]
) {{
    header_output, decoder_output, decode_block, exec_output = VMemBase(
        name, Name, ea_code, memacc_code, mem_flags, inst_flags,
        base_class='VsIndex',
        declare_template_base=VMemTemplateMacroDeclare,
        decode_template=VMemTemplateDecodeBlock,
        exec_template_base='VsIndex')
}};
